import cv2
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import path_finder as pf
from PIL import Image
from scipy.signal import get_window
import os

# Pick one image (pf.choose_path presumably prompts for a file); every image
# file found in that image's folder is then processed, in sorted name order.
image_path = pf.choose_path()
directory_path, image_name = os.path.split(image_path)
image_files = sorted(
    entry
    for entry in os.listdir(directory_path)
    if entry.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
)

# Results container and image counter.
Density_r = []
Bias = []
i = 0

# Characters of the file name that hold the bias value. Adjust when the naming
# scheme changes (earlier data sets used slice(5, 8) and slice(3, 7)).
BIAS_SLICE = slice(2, 6)

# get the bias encoded in each file name (the loop variable is not called
# `Image`, so the imported PIL.Image is not shadowed)
for file_name in image_files:
    bias_text = file_name[BIAS_SLICE]
    try:
        Bias.append(float(bias_text))
    except ValueError as err:
        raise ValueError(
            f"cannot parse bias from file name {file_name!r} "
            f"(characters {BIAS_SLICE.start}:{BIAS_SLICE.stop} give {bias_text!r})"
        ) from err

# get the image and FFT spectrum of every file and reduce it to a radial profile

# Window applied before the FFT to reduce edge leakage
# (e.g. 'hamming', 'hann', 'blackman', 'nuttall').
window_type = 'hamming'
# Half-height (in FFT pixels) of the central zoom region; the half-width is
# scaled by the image aspect ratio. The profile ROI spans 3x this size.
size_y = 50
# Number of one-pixel-wide rings in the radial profile (radii 0 .. n_rings-1).
n_rings = 30
# Set to True to display each image next to its zoomed FFT (one window per image).
show_fft_preview = False


def _radial_profile(fft_roi, k, n_rings):
    """Return the mean of ``fft_roi`` over each ring ``r <= dist < r + 1``.

    ``r`` runs over ``0 .. n_rings - 1``. The distance is measured from the ROI
    centre ``(width_roi // 2, height_roi // 2)`` with the x offset multiplied
    by ``k``. With ``k = image height / image width``, equal radii correspond
    to equal spatial frequency along both axes, because the FFT bin spacing is
    1/width in x and 1/height in y. A ring that contains no pixels gets 0.
    """
    height_roi, width_roi = fft_roi.shape
    x0, y0 = width_roi // 2, height_roi // 2  # fft center point
    yy, xx = np.indices(fft_roi.shape)
    # Same arithmetic form as a per-point sqrt((k*x0 - k*x)^2 + (y0 - y)^2).
    dist = np.sqrt((k * x0 - k * xx) ** 2 + (y0 - yy) ** 2)
    profile = np.zeros(n_rings)
    for r0 in range(n_rings):
        ring = (dist >= r0) & (dist < r0 + 1)
        if ring.any():
            profile[r0] = fft_roi[ring].mean()
    return profile


if not image_files:
    raise FileNotFoundError(f"no image files found in {directory_path!r}")

density_rows = []
for image_file in image_files:
    # flag 0 = load as single-channel grayscale; imread returns None on failure
    image0 = cv2.imread(os.path.join(directory_path, image_file), 0)
    if image0 is None:
        raise FileNotFoundError(f"could not read image {image_file!r}")
    # To analyse a sub-region only, crop here, e.g. image = image0[0:500, 400:660]
    image = image0
    height, width = image.shape

    # 2-D window as the outer product of the row and column 1-D windows.
    window = np.outer(get_window(window_type, height), get_window(window_type, width))
    windowed_image = image * window

    # Centred log-magnitude spectrum: 20 * ln|FFT|.
    fft_array = np.absolute(np.fft.fft2(np.flipud(windowed_image)))
    fft_array = 20 * np.log(np.fliplr(np.fft.fftshift(fft_array)))

    # Zero-frequency pixel: fftshift puts it at (height//2, width//2), and the
    # fliplr above moves its column to width - 1 - width//2 (one column left
    # of width//2 when width is even).
    cy, cx = height // 2, width - 1 - width // 2
    size_x = int(np.round(size_y * width / height))

    if show_fft_preview:
        zoom = fft_array[cy - size_y:cy + size_y, cx - size_x:cx + size_x]
        half_extent = 1000 * np.pi / height
        plt.figure(figsize=(8, 6))
        plt.subplot(121)
        plt.imshow(image, cmap='viridis')
        plt.title('Original Image')
        plt.axis('off')
        plt.subplot(122)
        plt.imshow(zoom, extent=[-half_extent, half_extent, -half_extent, half_extent],
                   origin='lower', cmap="Greys_r", vmin=np.mean(zoom),
                   vmax=(0.7 * np.max(zoom) + 0.3 * np.mean(zoom)))
        plt.title('FFT Result (Magnitude Spectrum)')
        plt.axis('on')
        plt.show()

    # ROI of +-3*size around the zero-frequency pixel, so that pixel sits at
    # the ROI centre used by _radial_profile.
    y_lo, y_hi = cy - 3 * size_y, cy + 3 * size_y
    x_lo, x_hi = cx - 3 * size_x, cx + 3 * size_x
    # Out-of-range bounds would make NumPy wrap or clip the slice silently and
    # move the ROI centre off the zero-frequency pixel.
    if y_lo < 0 or x_lo < 0 or y_hi > height or x_hi > width:
        raise ValueError(
            f"image {image_file!r} ({width}x{height}) is too small for "
            f"size_y={size_y}; reduce size_y"
        )
    fft_roi = fft_array[y_lo:y_hi, x_lo:x_hi]

    density_rows.append(_radial_profile(fft_roi, height / width, n_rings))

# One row per image (in image_files order), one column per ring radius.
Density_r = np.vstack(density_rows)

# plot the FFT spectrum: one row per image (bias), one column per ring (k)
plt.figure(figsize=(6, 4))

# Bias for each row, in the same order as the sorted image files. These are
# entered by hand; np.array(Bias) holds the values parsed from the file names.
bias = np.array([-6, -7, -8, -9, -10, -11, -12])
n_images, n_k = np.shape(Density_r)
if len(bias) != n_images:
    raise ValueError(
        f"{len(bias)} bias values given for {n_images} images; update `bias`"
    )

# Ring radii are 0 .. n_k-1 pixels; the k axis uses the same radii with a step
# of 2*pi*0.011 nm^-1 (~69 um^-1) per ring.
k = 2*np.pi*0.011*np.arange(n_k)
K, B = np.meshgrid(k, bias)
plt.pcolormesh(K, B, Density_r, cmap='viridis', shading='auto')
plt.xlabel('k(nm$^-$$^1$)')
plt.ylabel("V$_B$$_G$(V)")
#plt.ylabel("V$_b$$_i$$_a$$_s$(V)")
plt.title('FFT spectrum')
plt.colorbar(label='FFT intensity')

# Colour limits come from rings 5 and up, presumably to skip the strong
# low-frequency peak near the centre.
density_tail = Density_r[:, 5:]
vmax, vmin = np.max(density_tail), np.min(density_tail)
plt.clim(vmin=(0.1*vmax + 0.9*vmin), vmax=vmax)
# First bias at the bottom, last at the top (-6 to -12 for the values above).
plt.gca().set_ylim(bias[0], bias[-1])
plt.show()

## k step: ~70 um^-1 (presumably 2*pi*0.011 nm^-1 per ring pixel; derivation in the iPad scratch notes)
